/-
Copyright (c) 2023 Scott Morrison. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Scott Morrison
-/
import Mathlib.Data.ListM.Basic

/-!
# Best first search

We perform best first search of a tree or graph,
where the neighbours of a vertex are provided by a lazy list `α → ListM m α`.

We maintain a priority queue of visited-but-not-exhausted nodes,
ordered by the `Ord α` instance (the smallest node has the highest priority),
and at each step take the next child of the highest priority node in the queue.

This is useful in meta code for searching for solutions in the presence of alternatives.
It can be nice to represent the choices via a lazy list,
so that later choices need not be evaluated while we are still exploring earlier ones.

Options:
* `maxDepth` allows bounding the search depth
* `maxQueued` implements "beam" search,
  by discarding elements from the priority queue when it grows too large
* `removeDuplicates` maintains an `RBSet` of previously visited nodes;
  otherwise if the graph is not a tree nodes may be visited multiple times.
-/


variable {α : Type u} [Monad m] [Alternative m] [Ord α]

open Std ListM

/--
A single step of `bestFirstSearch`.

The state is an `RBMap` keyed by nodes of type `α` (ordered via `compare`); each node is mapped to
its depth together with the lazy list of its children that have not been produced yet.

* If the queue is empty, the step fails.
* Otherwise we look at the minimal node of the queue and try to pull one more child from its list.
  * If that list is exhausted, the node is erased from the queue and no node is reported.
  * If it yields a child `b`, the remainder of the list is stored back, `b` is inserted one level
    deeper along with its own children `f (n+1) b`, and `[b]` is reported.
* When `maxQueued = some q` and the queue now holds more than `q` entries,
  the maximal entry is erased.
-/
-- Results are reported as a `List α` (with zero or one element) rather than an `Option α`,
-- because that is the shape `fixl` expects in `bestFirstSearch`.
unsafe def bestFirstSearchAux
    (f : Nat → α → ListM m α) (maxQueued : Option Nat := none) :
    RBMap α (Nat × ListM m α) compare → m (RBMap α (Nat × ListM m α) compare × List α) :=
fun queue => do
  -- Nothing left to explore: signal the end of the search.
  let some (best, (depth, children)) := queue.min | failure
  -- The best node has no further children: retire it without reporting anything.
  let some (child, rest) ← uncons children | pure (queue.erase best, [])
  -- Put back the shortened child list, then enqueue the freshly discovered child.
  let grown := (queue.insert best (depth, rest)).insert child (depth + 1, f (depth + 1) child)
  -- Beam search: drop the maximal entry once the queue exceeds the bound.
  let trimmed :=
    match maxQueued with
    | none => grown
    | some bound =>
      if grown.size > bound then
        match grown.max with
        | some worst => grown.erase worst.1
        | none => unreachable!
      else
        grown
  pure (trimmed, [child])

/--
Best first search of the graph whose edges are given by `f : α → ListM m α`, starting from `a`,
presented as a lazy list of the nodes in the order they are discovered (beginning with `a` itself).

Internally a queue of nodes whose children have not all been produced is kept,
ordered by the `Ord α` instance. Each step pulls the next child of the minimal node in that queue.

* `maxDepth := some d`: a node at depth `n` (the start node has depth `0`) is treated as having
  no children whenever `d < n`.
* `maxQueued := some q`: after each insertion, if the queue holds more than `q` nodes its maximal
  node is dropped. This is a "beam" search: memory is bounded, but the search may be incomplete.
* `removeDuplicates` (on by default): an `RBSet` of the nodes seen so far is threaded through the
  search, and children already in that set are skipped.
  Without it, nodes of a graph that is not a tree may be reported more than once.
-/
unsafe def bestFirstSearch (f : α → ListM m α) (a : α)
    (maxDepth : Option Nat := none) (maxQueued : Option Nat := none) (removeDuplicates := true) :
    ListM m α :=
-- Depth-aware neighbour function: past the depth bound a node has no children.
let expand : Nat → α → ListM m α := fun depth node =>
  match maxDepth with
  | none => f node
  | some bound => if bound < depth then empty else f node
if removeDuplicates then
  -- Same neighbours, filtered through a state recording every node seen so far.
  let expandFresh : Nat → α → ListM (StateT.{u} (RBSet α compare) m) α := fun depth node =>
    (expand depth node).liftM >>= fun child => do
      let seen ← get
      if seen.contains child then failure
      set (seen.insert child)
      pure child
  -- Report the start node first, then iterate the single-step function; the start node is
  -- recorded as seen from the outset.
  cons (pure (some a,
      fixl (bestFirstSearchAux expandFresh maxQueued) (RBMap.single a (0, expandFresh 0 a))))
    |>.runState' (RBSet.empty.insert a)
else
  cons (pure (some a,
    fixl (bestFirstSearchAux expand maxQueued) (RBMap.single a (0, expand 0 a))))
